import torch
import torch.nn as nn
from IPython import embed

class DeepSet(nn.Module):
    """Pairwise team-comparison model with a Deep Sets style team encoder.

    A team's score is the sum of its players' scalar biases plus an MLP
    (``linear1`` -> ReLU -> ``linear2``) applied to the *sum* of the players'
    embeddings. ``forward`` returns ``sigmoid(score(winners) - score(losers))``.
    Both teams are scored by the same function, so swapping the arguments
    yields ``1 - p``.
    """

    def __init__(self, n_players, embed_dim=2, linear_dim=10):
        super(DeepSet, self).__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(n_players, embed_dim)

        self.linear1 = nn.Linear(embed_dim, linear_dim)
        self.linear2 = nn.Linear(linear_dim, 1)

        # Per-player scalar offset added to the team score.
        self.bias = nn.Embedding(n_players, 1)
        self.ReLU = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def _team_score(self, team):
        """Return the scalar score of each team in the batch.

        ``team`` is assumed to be an integer index tensor shaped
        (batch, team_size) -- TODO confirm against callers. Embeddings are
        pooled over dim 1 *before* the MLP (the Deep Sets rho(sum phi) form).
        Returns a tensor shaped (batch, 1).
        """
        pooled = self.embedding(team).sum(1)
        mlp_out = self.linear2(self.ReLU(self.linear1(pooled)))
        return self.bias(team).sum(1) + mlp_out

    def forward(self, winners, losers):
        """Return sigmoid(score(winners) - score(losers)), shaped (batch, 1)."""
        # Both sides go through the same helper. Previously the loser branch
        # applied linear1 before pooling, which made the two scores differ.
        return self.sigmoid(self._team_score(winners) - self._team_score(losers))

class DeepSet_single(nn.Module):
    """Single-team scorer operating on a dense team indicator vector.

    The input is cast to float and fed to ``nn.Linear(n_players, ...)`` layers,
    so its last dimension must equal ``n_players``. Presumably it is a multi-hot
    team membership vector -- confirm with callers. The output is a linear term
    (``bias``) plus an MLP over a learned linear embedding of the input.
    """

    def __init__(self, n_players, embed_dim=10, linear_dim=10):
        super().__init__()
        self.embed_dim = embed_dim

        # Creation order is kept so default random initialisation is unchanged.
        self.embedding = nn.Linear(n_players, embed_dim)
        self.linear1 = nn.Linear(embed_dim, linear_dim)
        self.linear2 = nn.Linear(linear_dim, 1)

        self.bias = nn.Linear(n_players, 1)
        self.ReLU = nn.ReLU()

    def forward(self, winners):
        """Return the unnormalised team score, shaped (..., 1)."""
        features = winners.float()
        hidden = self.ReLU(self.linear1(self.embedding(features)))
        linear_term = self.bias(features)
        return linear_term + self.linear2(hidden)

class LR(nn.Module):
    """Logistic pairwise model: each player has one learned scalar strength.

    A team's strength is the sum over dim 1 of its players' scalars, and
    ``forward`` returns ``sigmoid(strength(team1) - strength(team2))``.
    """

    def __init__(self, n_players):
        super().__init__()
        # One scalar per player, looked up by integer index.
        self.bias = nn.Embedding(n_players, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, team1, team2):
        """Return the probability-like score that team1 beats team2.

        Teams are assumed to be integer index tensors shaped
        (batch, team_size) -- TODO confirm. The result is shaped (batch, 1).
        """
        strength1 = self.bias(team1).sum(1)
        strength2 = self.bias(team2).sum(1)
        margin = strength1 - strength2
        return self.sigmoid(margin)

class LR_single(nn.Module):
    """Linear scorer over a dense team indicator vector.

    The input's last dimension must equal ``n_players``. It is cast to float
    and passed through ``nn.Linear(n_players, 1)``. No sigmoid is applied; the
    raw score is returned.
    """

    def __init__(self, n_players):
        super().__init__()
        self.bias = nn.Linear(n_players, 1)

    def forward(self, team1):
        """Return the unnormalised linear score of ``team1``, shaped (..., 1)."""
        return self.bias(team1.float())


class FHoi(nn.Module):
    """Pairwise model with a quadratic (player-interaction) team score.

    A team is turned into a count vector ``c`` (sum of one-hot rows), and its
    score is ``c . (W c + b)``, where ``W``/``b`` come from ``self.weight``.
    ``forward`` returns ``sigmoid(score(team1) - score(team2))``.
    """

    def __init__(self, n_players):
        super().__init__()
        self.n_players = n_players

        # Frozen identity embedding: row i is the one-hot vector for player i.
        # It is still registered as a (non-trainable) parameter, so state_dict
        # keys match earlier checkpoints.
        self.one_hot = nn.Embedding(self.n_players, self.n_players)
        self.one_hot.weight.requires_grad_(False)
        self.one_hot.weight.copy_(torch.eye(self.n_players))

        self.weight = nn.Linear(n_players, n_players)
        self.sigmoid = nn.Sigmoid()

    def _team_score(self, team):
        """Return ``c . self.weight(c)`` per batch row, shaped (batch, 1).

        ``team`` is assumed to be integer indices shaped (batch, team_size)
        -- TODO confirm; one-hot rows are summed over dim 1 into counts ``c``.
        """
        counts = self.one_hot(team).sum(1)
        projected = self.weight(counts)
        row = counts.unsqueeze(1)
        col = projected.unsqueeze(2)
        return torch.matmul(row, col).view(-1, 1)

    def forward(self, team1, team2):
        """Return sigmoid(score(team1) - score(team2)), shaped (batch, 1)."""
        margin = self._team_score(team1) - self._team_score(team2)
        return self.sigmoid(margin)
        
class FHoi_single(nn.Module):
    """Single-team scorer with linear plus pairwise-interaction terms.

    For a dense team vector ``x`` (last dim ``n_players``) the score is
    ``bias(x) + x^T W' x``. Here ``W'`` is ``W + W^T`` when ``sym`` is true and
    ``W`` otherwise, with the diagonal zeroed when ``mask`` is true. Both
    linear layers have no bias term.
    """

    def __init__(self, n_players):
        super(FHoi_single, self).__init__()

        self.n_players = n_players

        self.weight = nn.Linear(n_players, n_players, bias=False)
        self.bias = nn.Linear(n_players, 1, bias=False)

        # Off-diagonal mask (0 on the diagonal, 1 elsewhere). Registered as a
        # buffer so .to()/.cuda()/.double() move and cast it with the
        # parameters; previously it was a plain attribute that stayed on CPU
        # and broke forward(mask=True) on GPU. persistent=False keeps it out of
        # state_dict, so existing checkpoints still load (needs torch >= 1.6).
        self.register_buffer(
            'diag_mask', 1 - torch.eye(self.n_players), persistent=False)

        # Start with no self-interaction terms.
        with torch.no_grad():
            self.weight.weight.mul_(self.diag_mask)

    def forward(self, idx1, sym=True, mask=True):
        """Return the team score, shaped (batch, 1).

        idx1: tensor whose last dim is ``n_players``; it is cast to float.
            Presumably a multi-hot team indicator -- confirm with callers.
        sym: use ``W + W^T`` as the interaction matrix.
        mask: zero the diagonal of the interaction matrix before use.
        """
        idx1 = idx1.float()
        score1 = self.bias(idx1)

        if sym:
            w = self.weight.weight + self.weight.weight.t()
            if mask:
                w = w * self.diag_mask
            prod1 = torch.matmul(idx1, w)
        elif mask:
            prod1 = torch.matmul(idx1, self.weight.weight * self.diag_mask)
        else:
            # Equals idx1 @ W^T; the quadratic form below is the same for W
            # and W^T.
            prod1 = self.weight(idx1)

        # Batched quadratic form x^T (prod) -> one scalar per row.
        score2 = torch.matmul(idx1.unsqueeze(1), prod1.unsqueeze(2)).view(-1, 1)

        return score1 + score2
